In [ ]:
NAME = "Caden Truelick"

Fruits and Vegetables Image Classification Project

For this project, I am going to be using a dataset consisting of several thousand images of various fruits and vegetables and I will create machine learning models that can predict what type of fruit or vegetable an image is depicting. Some of the methods that I will use are listed below:

Unsupervised Models

  • K-Means Clustering Analysis
  • Dimensionality Reduction using Principal Component Analysis

Supervised Models

  • Artificial Neural Network Machine Learning Model
  • Transfer Learning Model

Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import plotly.express as px
import seaborn as sns

import tensorflow as tf

#for opening/extracting the zipfile
import zipfile

#for loading in custom data
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image
/usr/local/lib/python3.8/dist-packages/scipy/__init__.py:138: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.24.4)
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion} is required for this version of "

RUN THE FOLLOWING CELL ONCE TO UNZIP THE FILE

Command that extracts all the data from the Zip File into data folder
In [ ]:
#where r is read-mode
with zipfile.ZipFile('fruit_data.zip', 'r') as zip_ref:
    zip_ref.extractall('data')

Loading/Cleaning the Dataset

Make sure all the files are valid

In [2]:
from pathlib import Path
import imghdr

#directories of each folder
data_dir_train = 'data/train/'
data_dir_test = 'data/test/'
data_dir_valid = 'data/validation/'

#image extensions
image_extensions = ['.png', '.jpg', '.jpeg']
img_type_accepted_by_tf = ['bmp', 'gif', 'jpeg', 'png']

#check files to make sure it is the correct extension
def check_files(directory):
    for filepath in Path(directory).glob('**/*'):
        if filepath.suffix.lower() in image_extensions:
            img_type = imghdr.what(filepath)              #pulls the img extension from the file
            if img_type is None:
                print(f"{filepath} is not an image")
            elif img_type not in img_type_accepted_by_tf:
                print(f"{filepath} is a {img_type}, not accepted by TensorFlow")
    print(f'Completed {directory}') 
    
    
check_files(data_dir_train)
check_files(data_dir_test)
check_files(data_dir_valid)
Completed data/train/
Completed data/test/
Completed data/validation/

Pull image data into train, validation, and test datasets via their respective folder

In [3]:
train = tf.keras.utils.image_dataset_from_directory(data_dir_train,
                                                    color_mode = 'rgb',
                                                    batch_size = 36,
                                                    image_size = (256,256),
                                                    shuffle=True,
                                                    crop_to_aspect_ratio = True)

valid = tf.keras.utils.image_dataset_from_directory(data_dir_valid,
                                                    color_mode = 'rgb',
                                                    batch_size = 36,
                                                    image_size = (256,256),
                                                    shuffle=False,
                                                    crop_to_aspect_ratio = True)

test = tf.keras.utils.image_dataset_from_directory(data_dir_test,
                                                   color_mode = 'rgb',
                                                   batch_size = 36,
                                                   image_size = (256,256),
                                                   shuffle=False,
                                                   crop_to_aspect_ratio = True)
Found 3115 files belonging to 36 classes.
Found 351 files belonging to 36 classes.
Found 359 files belonging to 36 classes.
In [4]:
class_names = train.class_names
print(class_names)
['apple', 'banana', 'beetroot', 'bell pepper', 'cabbage', 'capsicum', 'carrot', 'cauliflower', 'chilli pepper', 'corn', 'cucumber', 'eggplant', 'garlic', 'ginger', 'grapes', 'jalepeno', 'kiwi', 'lemon', 'lettuce', 'mango', 'onion', 'orange', 'paprika', 'pear', 'peas', 'pineapple', 'pomegranate', 'potato', 'raddish', 'soy beans', 'spinach', 'sweetcorn', 'sweetpotato', 'tomato', 'turnip', 'watermelon']

Plot Some Images

def plotImages(data):

The plotImages function takes a dataset input data. The function takes a batch from the dataset and plots all of the images with the corresponding label.

In [5]:
def plotImages(data):
    plt.figure(figsize=(15,15))
    for images, labels in data.take(1):
        for i in range(36): 
            ax = plt.subplot(6, 6, i+1)
            plt.imshow(images[i].numpy().astype('uint8'))
            plt.title(class_names[labels[i]])
            plt.axis('off')
In [6]:
plotImages(train)

Unsupervised Learning

In [7]:
from sklearn.cluster import KMeans #for clustering analysis
from sklearn.decomposition import PCA #for dimensionality reduction

#flatten the image data
flat_images = []
for images, _ in train:
    flat_images.extend(images.numpy().reshape(images.shape[0], -1))
        

flat_images = np.array(flat_images)

K-means clustering analysis

In [25]:
kmeans = KMeans(n_clusters=36, random_state=123, n_init=10)
cluster_labels = kmeans.fit_predict(flat_images)
Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7f62cc525940>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'
Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7f62cc525940>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'
Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7f62cc525b80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'
Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7f62cc525b80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'
Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7f62cc525940>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'
Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7f62cc525940>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'
Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7f62cc525b80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'
Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7f62cc525b80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'
Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7f62cc525940>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'
Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7f62cc525940>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'
Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7f62cc525b80>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/usr/local/lib/python3.8/dist-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'

Dimensionality Reduction using PCA

In [9]:
pca = PCA(n_components=2)
reduced_data = pca.fit_transform(flat_images)

Visualize the clusters in 2D

In [10]:
df_2d = pd.DataFrame(data=reduced_data, columns=['PC1', 'PC2'])
df_2d['Cluster'] = cluster_labels

fig = px.scatter(df_2d, x='PC1', y='PC2', color='Cluster', opacity=0.8, size_max=10)
fig.update_layout(title='2D K-means Clustering')
fig.show()
In [11]:
pca_3d = PCA(n_components=3)
reduced_data_3d = pca_3d.fit_transform(flat_images)

df_3d = pd.DataFrame(data=reduced_data_3d, columns=['PC1', 'PC2', 'PC3'])
df_3d['Cluster'] = cluster_labels

fig = px.scatter_3d(df_3d, x='PC1', y='PC2', z='PC3', color='Cluster', opacity=0.8, size_max=10)
fig.update_layout(scene=dict(xaxis_title='PC 1', yaxis_title='PC 2', zaxis_title='PC 3'),
                  title='3D K-means Clustering')
fig.show()
In [12]:
#choose a cluster to display (minimum 5 images)
cluster_to_display = 0

for i in cluster_labels:
    images_in_cluster = np.where(cluster_labels == i)[0]
    num_images = len(images_in_cluster)
    if num_images >= 5:
        cluster_to_display = i
        break

#get the indicies in the cluster
images_in_cluster = np.where(cluster_labels == cluster_to_display)[0]

#display images from chosen cluster
plt.figure(figsize=(15, 5))
for i in range(5):  #displaying first 5 images
    image_idx = images_in_cluster[i]
    plt.subplot(1, 5, i + 1)
    
    #reshape back to original
    original_shape = train.element_spec[0].shape[1:]  # Get the shape of original images
    image = flat_images[image_idx].reshape(original_shape).astype('uint8')
    
    plt.imshow(image)
    plt.title(f'Cluster {cluster_to_display}')
    plt.axis('off')

plt.show()

Supervised Learning

Building the Neural Network Model

In [13]:
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dropout, Dense, Reshape, Rescaling

model = tf.keras.Sequential()

model.add(Rescaling(1./255, input_shape=(256,256,3)))

model.add(Conv2D(16, (3, 3), 1, activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(32, (3, 3), 1, activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), 1, activation='relu'))
model.add(MaxPooling2D((2, 2)))

model.add(Flatten())

model.add(Dense(256, activation='relu'))
model.add(Dense(36, activation='softmax')) #softmax because this is a multiclass classification
               
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 rescaling (Rescaling)       (None, 256, 256, 3)       0         
                                                                 
 conv2d (Conv2D)             (None, 254, 254, 16)      448       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 127, 127, 16)     0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 125, 125, 32)      4640      
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 62, 62, 32)       0         
 2D)                                                             
                                                                 
 conv2d_2 (Conv2D)           (None, 60, 60, 64)        18496     
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 30, 30, 64)       0         
 2D)                                                             
                                                                 
 flatten (Flatten)           (None, 57600)             0         
                                                                 
 dense (Dense)               (None, 256)               14745856  
                                                                 
 dense_1 (Dense)             (None, 36)                9252      
                                                                 
=================================================================
Total params: 14,778,692
Trainable params: 14,778,692
Non-trainable params: 0
_________________________________________________________________
In [14]:
model.compile(optimizer='adam',
              loss = 'sparse_categorical_crossentropy', #labels are provided as integers so we use sparse
              metrics=['accuracy'])
In [15]:
history = model.fit(train, epochs=10, batch_size=36, validation_data=valid)
test_loss, test_acc = model.evaluate(valid)
Epoch 1/10
87/87 [==============================] - 47s 512ms/step - loss: 3.4691 - accuracy: 0.0867 - val_loss: 2.4662 - val_accuracy: 0.2906
Epoch 2/10
87/87 [==============================] - 45s 503ms/step - loss: 2.5415 - accuracy: 0.2616 - val_loss: 1.6734 - val_accuracy: 0.5356
Epoch 3/10
87/87 [==============================] - 45s 499ms/step - loss: 2.0588 - accuracy: 0.3897 - val_loss: 1.2192 - val_accuracy: 0.6838
Epoch 4/10
87/87 [==============================] - 44s 491ms/step - loss: 1.4797 - accuracy: 0.5695 - val_loss: 0.6509 - val_accuracy: 0.8234
Epoch 5/10
87/87 [==============================] - 48s 540ms/step - loss: 0.8344 - accuracy: 0.7634 - val_loss: 0.4573 - val_accuracy: 0.8860
Epoch 6/10
87/87 [==============================] - 45s 502ms/step - loss: 0.4350 - accuracy: 0.8742 - val_loss: 0.3471 - val_accuracy: 0.9288
Epoch 7/10
87/87 [==============================] - 49s 544ms/step - loss: 0.2195 - accuracy: 0.9445 - val_loss: 0.2950 - val_accuracy: 0.9316
Epoch 8/10
87/87 [==============================] - 45s 501ms/step - loss: 0.1300 - accuracy: 0.9756 - val_loss: 0.3024 - val_accuracy: 0.9573
Epoch 9/10
87/87 [==============================] - 46s 509ms/step - loss: 0.1057 - accuracy: 0.9798 - val_loss: 0.2955 - val_accuracy: 0.9487
Epoch 10/10
87/87 [==============================] - 45s 503ms/step - loss: 0.0890 - accuracy: 0.9833 - val_loss: 0.2631 - val_accuracy: 0.9544
10/10 [==============================] - 2s 188ms/step - loss: 0.2631 - accuracy: 0.9544

Plot the History

Below are plots of the training loss and accuracy then the validation loss and accuracy, respectively.

In [16]:
#adapted from hw5

df = pd.DataFrame(history.history)

fig, ax1 = plt.subplots(2,1,figsize=(10,8))

ax1[0].plot(df['loss'], label='Training Loss')
ax1[0].plot(df['val_loss'], label='Validation Loss')
ax1[0].set_xlabel('Epoch')
ax1[0].legend()


ax1[1].plot(df['accuracy'], label='Training Accuracy')
ax1[1].plot(df['val_accuracy'], label='Validation Accuracy')
ax1[1].set_xlabel('Epoch')
ax1[1].legend()
Out[16]:
<matplotlib.legend.Legend at 0x7f62ec4eeb20>

Predictions on the Test Set

In [17]:
pred = model.predict(test)
y_pred = np.argmax(pred, axis=1)
print(y_pred)
[ 0  0  0  0 12  5  0  0  0  2  1  1  0  1  1  1  1  1 15  2  2  2  2  2
  2  2  2  2  2  3  3  3  3  3  3  3  5  3  3  4  4  4  4  4  4  4  4  4
  4  5  5  5  5  5  3  5  5  5  5  6  6  6  6  6  6  6  6  6  6  7  7  7
  7  7  7  7  7  7  7  8  8 26  8  8  8  8  8  8  8  9  9 31  9  9  9  9
  9 31  9 10 10 10 10 10 10 10 10 10 10 11 11 11 11 11 11 11 11 11 11 12
 12 12 12 12 12 12 12 12 12 13 13 13 13 13 13 13 13 13 13 14 14 14 14 14
 14 14 14 14 14 15 15 15 15 15 15 15 15 15 15 16 16 16 16 16 16 16 16 16
 16 17 17 17 17 17 17 17 17 17 17 18 18 18 18 18 18 18 18 18 18 19 19 19
 19 19 19 19 19 19 19 20 20 20 20 20 20 20 20 20 20 21 21 21 21 21 21 21
 21 21 21 22 22 22 22 22 22 22 22 22 22 23 23 23 23 23 23 23 23 23 23 24
 24 24 24 24 24 24 24 24 24 25 25 25 25 25 25 25 25 25 25 26 26 26 26 26
 26 26 26 26 26 27 27 20 27 27 27 27 27 25 27 28 28 28 28 28 28 28 28 28
 28 29 29 29 29 29 29 29 29 29 29 30 30 30 30 30 30 30 30 30 30 31 31 31
 31 31  9 31 31 31  9 32 32 32 32 32  6 32 29 32 32 33 33 33 33 33 33 33
 33 33 33 34 34 34 34 34 34 34 34 34 34 35 35 35 35 35 35 35 35 35 35]

Plot Results

Confusion Matrix

In [18]:
from sklearn.metrics import confusion_matrix

#get true labels from test dataset
true_labels = []
for _, labels in test:
    true_labels.extend(labels.numpy())

conf_matrix = confusion_matrix(true_labels, y_pred)

plt.figure(figsize=(15, 15))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.show()
In [19]:
def plot_images_with_labels(data, true_labels, predicted_labels, class_names):
    plt.figure(figsize=(15, 20))
    unique_classes = np.unique(true_labels)
    
    for i, class_label in enumerate(unique_classes):
        # Find an image with the true label from the current class
        true_label_idx = np.where(true_labels == class_label)[0][0]
        
        # Get the image tensor from the dataset
        image_tensor = None
        for images, labels in data:
            if class_label in labels.numpy():
                image_tensor = images[labels.numpy() == class_label][0]
                break
        
        # Find an image with the predicted label from the current class
        predicted_label_idx = np.where(predicted_labels == class_label)[0][0]
        
        # Plot the image with true and predicted labels
        plt.subplot(6, 6, i + 1)
        plt.imshow(image_tensor.numpy().astype('uint8'))
        true_class = class_names[true_labels[true_label_idx]]
        predicted_class = class_names[predicted_labels[true_label_idx]]  # Corrected index
        plt.title(f'True: {true_class}\nPredicted: {predicted_class}')
        plt.axis('off')

# Plot one image per class with true and corrected predicted labels
plot_images_with_labels(test, true_labels, y_pred, class_names)
plt.show()
In [20]:
def plot_incorrect_predictions(data, true_labels, predicted_labels, class_names):
    plt.figure(figsize=(15, 15))
    incorrect_indices = np.where(true_labels != predicted_labels)[0]

    for i, incorrect_idx in enumerate(incorrect_indices):
        image_tensor = None
        for images, labels in data:
            if predicted_labels[incorrect_idx] in labels.numpy():
                image_tensor = images[labels.numpy() == predicted_labels[incorrect_idx]][0]
                break

        plt.subplot(4, 5, i + 1)
        plt.imshow(image_tensor.numpy().astype('uint8'))
        true_class = class_names[true_labels[incorrect_idx]]
        predicted_class = class_names[predicted_labels[incorrect_idx]]
        plt.title(f'True: {true_class}\nPredicted: {predicted_class}')
        plt.axis('off')

# Plot incorrect predictions
plot_incorrect_predictions(test, true_labels, y_pred, class_names)
plt.show()
In [ ]: